#pragma once

#include <memory>
#include <string>
#include <mutex>
#include <vector>
#include <algorithm>
#include "ascend_util.h"
#include "triton/backend/backend_model.h"
#include "MxBase/E2eInfer/Model/Model.h"
#include "MxBase/E2eInfer/Tensor/Tensor.h"
#include "MxBase/E2eInfer/DataType.h"
#include "MxBase/E2eInfer/GlobalInit/GlobalInit.h"

namespace triton { namespace backend { namespace ascend {

//
// ModelState
//
// State associated with a model that is using this backend. An object
// of this class is created and associated with each
// TRITONBACKEND_Model.
//
class ModelState : public BackendModel {
public:
    static TRITONSERVER_Error* Create(
        TRITONBACKEND_Model* triton_model, ModelState** state);

    virtual ~ModelState() = default;

    // Load a compiled Ascend model using 'artifact_name' as the name for
    // the model artifact, targeting 'device'. Return in 'model_path' the
    // full path to the model file, and in 'ascend_model' the MxBase::Model
    // instance representing the loaded model.
    TRITONSERVER_Error* LoadModel(const std::string& artifact_name, const int32_t device,
        std::string* model_path, std::shared_ptr<MxBase::Model>* ascend_model);

    // Accessors for properties captured during model initialization.
    DynamicInfo GetDynamicInfo() const { return dynamic_info_; }
    DataFormat GetInputFormat() const { return input_format_; }
    // Returned by const reference to avoid copying the vector on each call.
    const std::vector<DataFormat>& GetOutputFormat() const { return output_format_; }

    ModelStrideInfo GetModelStrideInfo() const { return model_stride_info_; }

private:
    explicit ModelState(TRITONBACKEND_Model* triton_model);

    // Fill in any missing fields of the model configuration.
    TRITONSERVER_Error* AutoCompleteConfig();
    // Parse backend-specific parameters from the model configuration.
    TRITONSERVER_Error* ParseParameters();
    // Query 'ascend_model' for its dynamic shape/gear capabilities and
    // record them in 'dynamic_info_'.
    TRITONSERVER_Error* GetModelDynamicInfo(std::shared_ptr<MxBase::Model>* ascend_model);
    TRITONSERVER_Error* GetModelDynamicType(std::vector<std::vector<int64_t>>& input_shape,
        std::vector<std::vector<uint64_t>>& input_gear);
    TRITONSERVER_Error* GetModelDynamicShape(std::vector<std::vector<uint64_t>>& input_gear);

    TRITONSERVER_Error* GetModelHWStride(std::shared_ptr<MxBase::Model>* ascend_model);
    TRITONSERVER_Error* GetModelInferOutputShape(std::shared_ptr<MxBase::Model>* ascend_model,
        std::vector<uint32_t>& input_gear, std::vector<uint32_t>& output_gear);

private:
    bool is_initialized_ = false;
    // All inputs share the same format, so it is recorded only once.
    DataFormat input_format_;
    // Output formats are supplied by the user; one entry is recorded
    // per output.
    std::vector<DataFormat> output_format_;
    DynamicInfo dynamic_info_ {};

    ModelStrideInfo model_stride_info_;
    std::mutex model_mutex_;
};

}}};